{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "jUg9XQnOI5bc"
      },
      "source": [
        "## Pseudo code for computing relative Mahalanobis distance\n",
        "\n",
        "*Licensed under the Apache License, Version 2.0.*\n",
        "\n",
        "To run this in a public Colab, change the GitHub link: replace github.com with [githubtocolab.com](http://githubtocolab.com)\n",
        "\n",
        "\u003ca href=\"https://githubtocolab.com/google/uncertainty-baselines/blob/main/experimental/ood_clm/Pseudo_code_for_computing_relative_Mahalanobis_distance.ipynb\" target=\"_parent\"\u003e\u003cimg src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/\u003e\u003c/a\u003e\n",
        "\n",
        "This notebook demonstrates how to compute the relative Mahalanobis distance (RMD) used in\n",
        "[Out-of-Distribution Detection and Selective Generation for Conditional Language Models](https://arxiv.org/abs/2209.15558) for out-of-distribution (OOD) detection in conditional language models (CLMs).\n",
        "The RMD score is shown to be a highly accurate and lightweight OOD detection method for CLMs, as demonstrated on abstractive summarization and translation."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "iX41yOwnHL4m"
      },
      "outputs": [],
      "source": [
        "import numpy as np\n",
        "import sklearn.metrics\n",
        "import matplotlib.pyplot as plt\n",
        "\n",
        "\n",
        "import ood_utils  # local file import from baselines.jft"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ey44b79vA1et"
      },
      "source": [
        "## Steps for computing Relative Mahalanobis distance (RMD) OOD score"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "zsJsJngE_6Yw"
      },
      "outputs": [],
      "source": [
        "# (1) Prepare feature embeddings, embs_train_ind (NxD), for in-domain training data\n",
        "\n",
        "# (2) Prepare the same number of feature embeddings, embs_train_ood (NxD), for\n",
        "# general domain data (e.g. C4 for summarization, or ParaCrawl for translation).\n",
        "\n",
        "# Here we create dummy values for the feature embeddings: both domains are\n",
        "# Gaussians with identity covariance that differ only in their mean.\n",
        "SEED = 0  # fixed so that Restart and Run All reproduces the same numbers\n",
        "np.random.seed(SEED)  # seeds NumPy's global generator used by np.random.* calls\n",
        "\n",
        "N = 1000  # sample size\n",
        "D = 256  # embedding dimension\n",
        "mu_ind = np.zeros(D)  # in-domain mean\n",
        "mu_ood = np.ones(D) * 0.4  # general-domain mean, offset by 0.4 in every dimension\n",
        "sigma = np.identity(D)  # covariance shared by both domains\n",
        "embs_train_ind = np.random.multivariate_normal(mu_ind, sigma, N)\n",
        "embs_train_ood = np.random.multivariate_normal(mu_ood, sigma, N)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "EJ_yyaMLAyXx"
      },
      "outputs": [],
      "source": [
        "# (3) Prepare feature embeddings, embs_test_ind and embs_test_ood, for the\n",
        "# test in-domain and test OOD data.\n",
        "\n",
        "# Dummy test embeddings: N samples around each of the two means defined above,\n",
        "# drawn in-domain first and then OOD.\n",
        "embs_test_ind, embs_test_ood = [\n",
        "    np.random.multivariate_normal(mu, sigma, N) for mu in (mu_ind, mu_ood)\n",
        "]"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "3AMc3TTBQ8pK"
      },
      "outputs": [],
      "source": [
        "# (4) Fit Gaussian distributions for in-domain and general domain respectively.\n",
        "# NOTE(review): the 2nd and 3rd arguments are presumably per-example labels and\n",
        "# the set of class ids (all zeros here, i.e. a single class) -- confirm against\n",
        "# ood_utils.compute_mean_and_cov.\n",
        "mean_list, cov = ood_utils.compute_mean_and_cov(\n",
        "    embs_train_ind, np.zeros(N), np.zeros(1))\n",
        "mean_list0, cov0 = ood_utils.compute_mean_and_cov(\n",
        "    embs_train_ood, np.zeros(N), np.zeros(1))\n",
        "\n",
        "# (5) Compute RMD OOD score for the test examples: dist - dist_0\n",
        "\n",
        "# Raw Mahalanobis distance to the in-domain Gaussian, flattened to 1-D\n",
        "md_ind, md_ood = [\n",
        "    ood_utils.compute_mahalanobis_distance(embs, mean_list, cov).reshape(-1)\n",
        "    for embs in (embs_test_ind, embs_test_ood)\n",
        "]\n",
        "\n",
        "# Raw Mahalanobis distance to the general-domain (background) Gaussian\n",
        "md0_ind, md0_ood = [\n",
        "    ood_utils.compute_mahalanobis_distance(embs, mean_list0, cov0).reshape(-1)\n",
        "    for embs in (embs_test_ind, embs_test_ood)\n",
        "]\n",
        "\n",
        "# Relative Mahalanobis distance\n",
        "rmd_ind = md_ind - md0_ind\n",
        "rmd_ood = md_ood - md0_ood"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "IV41iUDdDR4N"
      },
      "outputs": [],
      "source": [
        "# (6) Compute AUROC for OOD detection\n",
        "scores = {\n",
        "    'Mahalanobis distance': (md_ind, md_ood),\n",
        "    'Relative Mahalanobis distance': (rmd_ind, rmd_ood),\n",
        "}\n",
        "fig, ax = plt.subplots(nrows=1, ncols=2, figsize=(12, 5))\n",
        "\n",
        "# One histogram panel per score type; in-domain examples are labelled 0 and\n",
        "# OOD examples 1 before the metrics are computed.\n",
        "for panel, (score_name, (scores_ind, scores_ood)) in zip(ax, scores.items()):\n",
        "  labels = np.concatenate((np.zeros_like(scores_ind), np.ones_like(scores_ood)))\n",
        "  all_scores = np.concatenate((scores_ind, scores_ood))\n",
        "  ood_metrics = ood_utils.compute_ood_metrics(labels, all_scores)\n",
        "  print(f'OOD metrics based on {score_name}', ood_metrics)\n",
        "  panel.hist([scores_ind, scores_ood], label=('IND', 'OOD'), bins=20)\n",
        "  auroc = ood_metrics['auroc']\n",
        "  panel.title.set_text(f'{score_name}: {auroc:.4f}')\n",
        "  panel.legend()\n",
        "plt.show()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "4fag0M3nJ-Gk"
      },
      "outputs": [],
      "source": []
    }
  ],
  "metadata": {
    "colab": {
      "last_runtime": {
        "build_target": "//learning/grp/tools/ml_python:ml_notebook",
        "kind": "private"
      },
      "private_outputs": true,
      "provenance": [
        {
          "file_id": "19P8_nAmweaHLE5u7Gf71NQI-cWsYi2ss",
          "timestamp": 1689098876460
        },
        {
          "file_id": "/piper/depot/google3/third_party/py/uncertainty_baselines/experimental/ood_lm/Pseudo_code_for_computing_relative_Mahalanobis_distance.ipynb?workspaceId=jjren:ood-clm::citc",
          "timestamp": 1689031743473
        },
        {
          "file_id": "/piper/depot/google3/third_party/py/uncertainty_baselines/experimental/ood_clm/Pseudo_code_for_computing_relative_Mahalanobis_distance.ipynb?workspaceId=jjren:ood-clm::citc",
          "timestamp": 1689031343152
        },
        {
          "file_id": "1k-z29GHaqtE7rw5TlV3l7bz7glu9zuqj",
          "timestamp": 1689007538570
        }
      ]
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
